# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Component specification for the "Train Model" pipeline step.
# The container runs /pipelines/component/src/train.py and is told to execute
# the function `train_model`; every declared input and output is forwarded to
# that script on the command line.
name: Train Model
inputs:
# Required inputs.
- {name: project_id, type: String}
- {name: data_region, type: String}
- {name: hptune_region, type: String}
- {name: data_pipeline_root, type: String}
- {name: input_data_schema, type: String}
- {name: training_container_image_uri, type: String}
- {name: serving_container_image_uri, type: String}
- {name: custom_job_service_account, type: String}
- {name: input_dataset, type: Dataset}
- {name: machine_type, type: String}
- {name: accelerator_type, type: String}
- {name: accelerator_count, type: Integer}
# Optional inputs: their flags are only emitted when a value is supplied
# (see the `if` / `isPresent` entries under `args`).
- {name: vpc_network, type: String, optional: true}
- {name: train_additional_args, type: String, optional: true}
- {name: output_model_file_name, type: String, optional: true}
- {name: hp_config_max_trials, type: Integer, optional: true}
- {name: hp_config_suggestions_per_request, type: Integer, optional: true}
outputs:
- {name: output_model, type: Model}
- {name: basic_metrics, type: Metrics}
- {name: classification_metrics, type: ClassificationMetrics}
- {name: feature_importance_dataset, type: Dataset}
- {name: instance_schema, type: Artifact}
implementation:
  container:
    # The {{...}} placeholders are template variables substituted before this
    # spec is consumed. The value is quoted because an unquoted scalar that
    # starts with `{` is parsed by YAML as a flow mapping; quoting does not
    # change the rendered string.
    image: "{{af_registry_location}}-docker.pkg.dev/{{project_id}}/{{af_registry_name}}/component-base:latest"
    command: [python, /pipelines/component/src/train.py]
    # Flag spellings (hyphens vs. underscores) are kept exactly as they were;
    # NOTE(review): they presumably must match the parser in train.py.
    args:
    - --executor_input
    - {executorInput: null}
    - --function_to_execute
    - train_model
    # Required inputs passed by value.
    - --project-id
    - {inputValue: project_id}
    - --data-region
    - {inputValue: data_region}
    - --hptune-region
    - {inputValue: hptune_region}
    - --data-pipeline-root
    - {inputValue: data_pipeline_root}
    - --training-container-image-uri
    - {inputValue: training_container_image_uri}
    - --serving-container-image-uri
    - {inputValue: serving_container_image_uri}
    - --custom-job-service-account
    - {inputValue: custom_job_service_account}
    - --machine-type
    - {inputValue: machine_type}
    - --accelerator-type
    - {inputValue: accelerator_type}
    - --accelerator-count
    - {inputValue: accelerator_count}
    # The dataset artifact is handed over as a local file path, not a value.
    - --input-dataset
    - {inputPath: input_dataset}
    - --input_data_schema
    - {inputValue: input_data_schema}
    # Output artifacts: the script receives the path it should write to.
    - --output-model
    - {outputPath: output_model}
    - --basic-metrics
    - {outputPath: basic_metrics}
    - --classification-metrics
    - {outputPath: classification_metrics}
    - --feature-importance-dataset
    - {outputPath: feature_importance_dataset}
    # The `instance_schema` output is exposed under this differently named flag.
    - --instance-schema-dataset
    - {outputPath: instance_schema}
    # Optional inputs: each flag/value pair is emitted only when the input is
    # present, so an omitted input does not leave a flag without its value.
    - if:
        cond: {isPresent: vpc_network}
        then: [--vpc-network, {inputValue: vpc_network}]
    - if:
        cond: {isPresent: output_model_file_name}
        then: [--output_model_file_name, {inputValue: output_model_file_name}]
    - if:
        cond: {isPresent: train_additional_args}
        then: [--train-additional-args, {inputValue: train_additional_args}]
    - if:
        cond: {isPresent: hp_config_suggestions_per_request}
        then:
        - --hp-config-suggestions-per-request
        - {inputValue: hp_config_suggestions_per_request}
    - if:
        cond: {isPresent: hp_config_max_trials}
        then: [--hp-config-max-trials, {inputValue: hp_config_max_trials}]
